BG/NBD Model - Stan Implementation

Author

Abdullah Mahmood

Published

March 20, 2025

In this notebook we show how to fit a BG/NBD model in Stan. We compare the results with the lifetimes package. The model is presented in the paper: Fader, P. S., Hardie, B. G., & Lee, K. L. (2005). “Counting your customers” the easy way: An alternative to the Pareto/NBD model. Marketing science, 24(2), 275-284.

Imports

Packages

Code
from utils import CDNOW, Stan, StanQuap
from scipy.optimize import minimize
import arviz as az
import polars as pl
import numpy as np


import matplotlib.pyplot as plt
%config InlineBackend.figure_formats = ['svg']

Data

# CDNOW transaction data, split at 273 days (≈ 39 weeks) of calibration.
data = CDNOW(master=False, calib_p=273) # 39 week calibration period
# One row per customer: repeat-purchase count (P1X), recency (t_x), age (T).
rfm_data = data.rfm_summary().select("P1X", "t_x", "T").collect().to_numpy()
p1x, t_x, T = rfm_data[:, 0], rfm_data[:, 1], rfm_data[:, 2]

Recall from the paper the following definitions:

  • p1x represents the number of repeat purchases the customer has made. This means that it’s one less than the total number of purchases. This is actually slightly wrong. It’s the count of time periods the customer had a purchase in. So if using days as units, then it’s the count of days the customer had a purchase on.
  • T represents the age of the customer in whatever time units chosen (weekly, in the above dataset). This is equal to the duration between a customer’s first purchase and the end of the period under study.
  • t_x represents the age of the customer when they made their most recent purchase. This is equal to the duration between a customer’s first purchase and their latest purchase. (Thus if they have made only 1 purchase, the recency is 0.)

Model Specification

The BG/NBD model is a probabilistic model that describes the buying behavior of a customer in the non-contractual setting. It is based on the following assumptions for each customer:

Frequency Process

  1. While active, the time between transactions is distributed exponential with transaction rate \(\lambda\), i.e.

\[ f(t_{j} \mid t_{j-1}; \lambda) = \lambda \exp(-\lambda (t_{j} - t_{j - 1})), \quad t_{j} \geq t_{j - 1} \geq 0 \]

  2. Heterogeneity in \(\lambda\) follows a gamma distribution with pdf

\[ f(\lambda \mid r, \alpha) = \frac{\alpha^{r}\lambda^{r - 1}\exp(-\lambda \alpha)}{\Gamma(r)}, \quad \lambda > 0 \]

Dropout Process

  1. After any transaction, a customer becomes inactive with probability \(p\).

  2. Heterogeneity in \(p\) follows a beta distribution with pdf

\[ f(p \mid a, b) = \frac{\Gamma(a + b)}{\Gamma(a) \Gamma(b)} p^{a - 1}(1 - p)^{b - 1}, \quad 0 \leq p \leq 1 \]

  3. The transaction rate \(\lambda\) and the dropout probability \(p\) vary independently across customers.

Instead of estimating \(\lambda\) and \(p\) for each specific customer, we do it for a randomly chosen customer, i.e. we work with the expected values of the parameters. Hence, we are interested in finding the posterior distribution of the parameters \(r\), \(\alpha\), \(a\), and \(b\).

Analytical MLE

Standard SciPy Implementation

Code
import numpy as np
from scipy.special import gammaln, hyp2f1, gamma, factorial

def bgnbd_ll(x, p1x, t_x, T):
    """Negative log-likelihood of the BG/NBD model (Fader, Hardie & Lee 2005).

    Parameters
    ----------
    x : sequence of four floats, (r, alpha, a, b).
    p1x : array of repeat-purchase counts per customer.
    t_x : array of recencies (time of last repeat purchase).
    T : array of customer ages (length of observation window).

    Returns the summed negative log-likelihood, suitable for `minimize`.
    """
    r, alpha, a, b = x

    def _safe_log(v):
        # Guard against log(0)/log(negative) when the unconstrained
        # optimizer probes near (or past) the parameter-space boundary.
        return np.log(np.clip(v, 1e-10, None))

    # Gamma ratio (purchase process) and Beta ratio (dropout heterogeneity),
    # both on the log scale.
    mix = (
        gammaln(r + p1x)
        - gammaln(r)
        + gammaln(a + b)
        + gammaln(b + p1x)
        - gammaln(b)
        - gammaln(a + b + p1x)
    )
    rate = r * _safe_log(alpha) - (r + p1x) * _safe_log(alpha + t_x)

    # Natural-scale pieces: probability mass of "still alive at T" and the
    # dropout-after-last-purchase term (the latter only applies if p1x > 0).
    survive = ((alpha + t_x) / (alpha + T)) ** (r + p1x)
    tail = np.where(
        p1x > 0,
        _safe_log(survive + a / (b + p1x - 1)),
        _safe_log(survive),
    )

    return -np.sum(mix + rate + tail)

def bgnbd_est():
    """Fit the BG/NBD model by maximum likelihood with SciPy's BFGS.

    Reads the module-level `p1x`, `t_x`, `T` arrays. Returns the full
    OptimizeResult: `x` holds (r, alpha, a, b) and `hess_inv` approximates
    the variance-covariance matrix used for Wald intervals below.
    """
    # Small positive starting values. BFGS is unconstrained (it ignores
    # bounds — the unused `bnds` list was removed); the 1e-10 clipping
    # inside bgnbd_ll keeps the objective finite near the boundary.
    guess = {'r': 0.01, 'alpha': 0.01, 'a': 0.01, 'b': 0.01}

    return minimize(
        bgnbd_ll,
        x0=list(guess.values()),
        method="BFGS",
        args=(p1x, t_x, T),
    )

# Run the MLE fit and report point estimates plus Wald intervals.
result = bgnbd_est()
r, alpha, a, b = result.x
ll = result.fun

print(
f"""r = {r:0.4f}
α = {alpha:0.4f}
a = {a:0.4f}
b = {b:0.4f}
Log-Likelihood = {-ll:0.4f}"""
)

index = ["r", "α", "a", "b"]  # fixed accidental chained assignment `index = index=[...]`
scipy_params = result.x
var_mat = result.hess_inv  # BFGS inverse-Hessian approximation
se = np.sqrt(np.diag(var_mat))
# NOTE(review): 1.96 is the z value for a 95% interval (2.5%/97.5%), but the
# columns are labelled 5.5%/94.5% (an 89% interval would use z ≈ 1.60) —
# confirm which was intended.
plo = scipy_params - 1.96 * se
phi = scipy_params + 1.96 * se
pl.DataFrame(
    {
        "Parameter": index,
        "Coef": scipy_params,
        "SE (Coef)": se,
        "5.5%": plo,
        "94.5%": phi,
    }
)
r = 0.2426
α = 4.4136
a = 0.7929
b = 2.4259
Log-Likelihood = -9582.4292
shape: (4, 5)
Parameter Coef SE (Coef) 5.5% 94.5%
str f64 f64 f64 f64
"r" 0.242595 0.012417 0.218258 0.266931
"α" 4.413603 0.476973 3.478735 5.348471
"a" 0.792923 0.180615 0.438918 1.146928
"b" 2.425909 0.777435 0.902136 3.949683

Cameron Davidson-Pilon’s lifetimes Implementation

Source: BetaGeoFitter, fit function

Code
import autograd.numpy as np
from autograd.scipy.special import gammaln, beta, gamma
from autograd import value_and_grad, hessian

def negative_log_likelihood(log_params, freq, rec, T):
    """Negative BG/NBD log-likelihood, parameterized on the log scale.

    `log_params` holds log(r), log(alpha), log(a), log(b); optimizing on
    the log scale keeps the unconstrained search away from non-positive
    parameter values. `freq`, `rec`, `T` are the per-customer repeat
    counts, recencies and ages.
    """
    r, alpha, a, b = np.exp(log_params)

    # Purchase-process contribution: Gamma ratio plus the alpha^r factor.
    ln_purchase = gammaln(r + freq) - gammaln(r) + r * np.log(alpha)
    # Beta-mixing contribution of the dropout heterogeneity.
    ln_mixing = gammaln(a + b) + gammaln(b + freq) - gammaln(b) - gammaln(a + b + freq)

    # Two branches: still alive at T vs. dropped out right after the last
    # purchase. np.maximum(freq, 1) keeps the log argument positive for
    # zero-repeat customers, whose churn branch is masked out below anyway.
    ln_alive = -(r + freq) * np.log(alpha + T)
    ln_churned = np.log(a) - np.log(b + np.maximum(freq, 1) - 1) - (r + freq) * np.log(rec + alpha)

    # Manual log-sum-exp for numerical stability; (freq > 0) zeroes the
    # churn branch for customers with no repeat purchases.
    pivot = np.maximum(ln_alive, ln_churned)
    ll = ln_purchase + ln_mixing + np.log(np.exp(ln_alive - pivot) + np.exp(ln_churned - pivot) * (freq > 0)) + pivot

    return -ll.sum()

def BetaGeoFitter(guess=None):
    """Replicate lifetimes' BetaGeoFitter.fit via SciPy.

    Optimizes the log-scale parameters with analytical gradients from
    autograd; exponentiate `result.x` to recover (r, alpha, a, b).

    Parameters
    ----------
    guess : dict or None
        Starting values for (r, alpha, a, b); None uses lifetimes'
        defaults. (Replaced the mutable default-argument dict — a shared
        mutable default is a Python anti-pattern.)
    """
    if guess is None:
        guess = {'r': 0.1, 'alpha': 0.1, 'a': 0.0, 'b': 0.1}

    # Unconstrained search (log-scale parameters need no bounds);
    # jac=True because value_and_grad returns (value, gradient).
    return minimize(
        value_and_grad(negative_log_likelihood),
        jac=True,
        method=None,
        args=(p1x, t_x, T),
        tol=1e-7,
        x0=list(guess.values()),
        options={'disp': True}
    )

# Fit on the log scale, then map estimates and covariance back.
result = BetaGeoFitter()
r, alpha, a, b = np.exp(result.x)
ll = result.fun

print(
f"""r = {r:0.4f}
α = {alpha:0.4f}
a = {a:0.4f}
b = {b:0.4f}
Log-Likelihood = {-ll:0.4f}"""
)

index = ["r", "α", "a", "b"]  # fixed accidental chained assignment `index = index=[...]`
lifetimes_params = np.exp(result.x)
# Covariance on the natural scale: delta method (params**2 scales rows of
# the log-scale inverse Hessian; only the diagonal is used below).
hessian_mat = hessian(negative_log_likelihood)(result.x, p1x, t_x, T)
var_mat = (lifetimes_params ** 2) * np.linalg.inv(hessian_mat) # Variance-Covariance Matrix
se = np.sqrt(np.diag(var_mat))  # Standard Error
# NOTE(review): 1.96 is the 95% z value but the columns are labelled
# 5.5%/94.5% (89%) — confirm which was intended.
plo = lifetimes_params - 1.96 * se
phi = lifetimes_params + 1.96 * se
pl.DataFrame(
    {
        "Parameter": index,
        "Coef": lifetimes_params,
        "SE (Coef)": se,
        "5.5%": plo,
        "94.5%": phi,
    }
)
Optimization terminated successfully.
         Current function value: 9582.429207
         Iterations: 21
         Function evaluations: 26
         Gradient evaluations: 26
r = 0.2426
α = 4.4136
a = 0.7929
b = 2.4259
Log-Likelihood = -9582.4292
shape: (4, 5)
Parameter Coef SE (Coef) 5.5% 94.5%
str f64 f64 f64 f64
"r" 0.242595 0.012557 0.217982 0.267207
"α" 4.413602 0.378224 3.672283 5.154921
"a" 0.792922 0.185734 0.428884 1.15696
"b" 2.425907 0.705414 1.043295 3.808519

Stan Model

Standard Parameters

import numpy as np

# Stan implementation of the BG/NBD likelihood (Fader, Hardie & Lee 2005,
# eq. 10) under the natural (r, alpha, a, b) parameterization.
# NOTE(review): the "priors on log parameters" comment inside the Stan code
# looks inaccurate — the sampling statements act on the natural-scale
# parameters, not their logs.
stan_code = '''
data {
    int<lower=0> N;               // Number of customers
    array[N] int<lower=0> X;      // Number of transactions per customer
    vector<lower=0>[N] T;         // Total observation time per customer
    vector<lower=0>[N] Tx;        // Time of last transaction (0 if X=0)
}

parameters {
    real<lower=0> r;                   // gamma shape (r)
    real<lower=0> alpha;               // gamma scale (alpha)
    real<lower=0, upper=5> a;          // beta shape 1 (a)
    real<lower=0, upper=5> b;          // beta shape 2 (b)
}

model {
    // Weakly informative priors on log parameters
    r ~ weibull(2, 1);
    alpha ~ weibull(2, 10);
    a ~ uniform(0, 5);
    b ~ uniform(0, 5);

    for (n in 1:N) {
        int x = X[n];
        real tx = Tx[n];
        real t = T[n];
    
        if (x == 0) {
              // Likelihood for X=0: (alpha/(alpha + t))^r
              target += r * (log(alpha) - log(alpha + t));
        } else {
              // Term 1: B(a, b + x)/B(a, b) * Γ(r + x)/Γ(r) * (alpha/(alpha + t))^(r + x)
              real beta_term1 = lbeta(a, b + x) - lbeta(a, b);
              real gamma_term = lgamma(r + x) - lgamma(r);
              real term1 = gamma_term + beta_term1 + r * log(alpha) - (r + x) * log(alpha + t);
            
              // Term 2: B(a + 1, b + x - 1)/B(a, b) * Γ(r + x)/Γ(r) * (alpha/(alpha + tx))^(r + x)
              real beta_term2 = lbeta(a + 1, b + x - 1) - lbeta(a, b);
              real term2 = gamma_term + beta_term2 + r * log(alpha) - (r + x) * log(alpha + tx);
            
              // Log-sum-exp for numerical stability
              target += log_sum_exp(term1, term2);
        }
    }
}
'''
Code
# Stan data block: counts as ints, time vectors as plain lists.
data = {
    'N': len(p1x),
    'X': p1x.flatten().astype(int).tolist(),
    'T': T.flatten().tolist(),
    'Tx': t_x.flatten().tolist()
}

# MAP estimate via L-BFGS plus a quadratic (Laplace) approximation of the
# posterior around it. NOTE(review): jacobian=False presumably skips the
# change-of-variables adjustment — confirm in utils.StanQuap.
stan_model = StanQuap(stan_file='stan_models/bg-nbd', stan_code=stan_code, data=data, algorithm='LBFGS', jacobian=False, tol_rel_grad=1e-7, iter=5000)

# Laplace-approximation Wald intervals around the Stan MAP estimate.
index = ["r", "α", "a", "b"]  # fixed accidental chained assignment `index = index=[...]`
params = np.array(list(stan_model.opt_model.stan_variables().values()))
var_mat = stan_model.vcov_matrix()  # Variance-Covariance Matrix (dropped stray `var_mar` alias)
se = np.sqrt(np.diag(var_mat))  # Standard Error
plo = params - 1.96 * se
phi = params + 1.96 * se
pl.DataFrame(
    {
        "Parameter": index,
        "Coef": params,
        "SE (Coef)": se,
        "5.5%": plo,
        "94.5%": phi,
    }
)
shape: (4, 5)
Parameter Coef SE (Coef) 5.5% 94.5%
str f64 f64 f64 f64
"r" 0.243678 0.012593 0.218995 0.268361
"α" 4.44677 0.379503 3.702944 5.190596
"a" 0.792205 0.185778 0.428081 1.156329
"b" 2.42758 0.706727 1.042396 3.812764
Code
# Cross-check: evaluate the analytical negative log-likelihood at the
# Stan MAP estimate; should be close to the SciPy/lifetimes optimum.
x = [stan_model.opt_model.optimized_params_pd['r'][0],
     stan_model.opt_model.optimized_params_pd['alpha'][0],
     stan_model.opt_model.optimized_params_pd['a'][0],
     stan_model.opt_model.optimized_params_pd['b'][0]]

print("Log-Likelihood:", bgnbd_ll(x, p1x, t_x, T))
# print(negative_log_likelihood(np.log(np.array(x)), p1x, t_x, T))
Log-Likelihood: 9582.433433869468

Modified Parameters

# Same BG/NBD likelihood, with the beta dropout distribution reparameterized
# as (phi_dropout, kappa_dropout): a = phi*kappa, b = (1-phi)*kappa, i.e.
# phi is the mean dropout probability and kappa the concentration. This
# mean/concentration form is often easier to sample than raw (a, b).
stan_code = '''
data {
    int<lower=0> N;               // Number of customers
    array[N] int<lower=0> X;      // Number of transactions per customer
    vector<lower=0>[N] T;         // Total observation time per customer
    vector<lower=0>[N] Tx;        // Time of last transaction (0 if X=0)
}

parameters {
    real<lower=0> r;                         // Shape parameter for the Gamma prior on purchase rate
    real<lower=0> alpha;                     // Scale parameter for purchase rate
    real<lower=0, upper=1> phi_dropout;      // Mixture weight for dropout process (Uniform prior)
    real<lower=1> kappa_dropout;             // Scale parameter for dropout (Pareto prior)
}

transformed parameters {
    real a = phi_dropout * kappa_dropout;       // Dropout shape parameter (controls early dropout likelihood)
    real b = (1 - phi_dropout) * kappa_dropout; // Dropout scale parameter (controls later dropout likelihood)
}

model {
    // Priors:
    r ~ weibull(2, 1);                // Prior on r (purchase rate shape parameter)
    alpha ~ weibull(2, 10);           // Prior on alpha (purchase rate scale parameter)
    phi_dropout ~ uniform(0,1);       // Mixture component for dropout process
    kappa_dropout ~ pareto(1,1);      // Scale of dropout process

    for (n in 1:N) {
        int x = X[n];                 // Number of transactions for customer n
        real tx = Tx[n];              // Time of last transaction
        real t = T[n];                // Total observation time

        if (x == 0) {
            // Likelihood for customers with zero transactions:
            // Probability of no purchases during (0, T): (alpha/(alpha + t))^r
            // Likelihood for X=0: (alpha/(alpha + t))^r
            target += r * (log(alpha) - log(alpha + t));
        } else {
            // Term 1: Probability of surviving until T and making x purchases
            // Term 1: B(a, b + x)/B(a, b) * Γ(r + x)/Γ(r) * (alpha/(alpha + t))^(r + x)
            real beta_term1 = lbeta(a, b + x) - lbeta(a, b);  // Beta function term
            real gamma_term = lgamma(r + x) - lgamma(r);       // Gamma function term
            real term1 = gamma_term + beta_term1 + r * log(alpha) - (r + x) * log(alpha + t);
            
            // Term 2: Probability of surviving until Tx, then dropping out
            // Term 2: B(a + 1, b + x - 1)/B(a, b) * Γ(r + x)/Γ(r) * (alpha/(alpha + tx))^(r + x)
            real beta_term2 = lbeta(a + 1, b + x - 1) - lbeta(a, b);
            real term2 = gamma_term + beta_term2 + r * log(alpha) - (r + x) * log(alpha + tx);
            
            // Log-sum-exp for numerical stability
            target += log_sum_exp(term1, term2);
        }
    }
}
'''
Code
# Same data block as the natural-parameter model.
data = {
    'N': len(p1x),
    'X': p1x.flatten().astype(int).tolist(),
    'T': T.flatten().tolist(),
    'Tx': t_x.flatten().tolist()
}

# MAP fit of the (phi, kappa) model; generated_var tells StanQuap that
# a and b are derived quantities, not sampled parameters.
stan_model = StanQuap(stan_file='stan_models/bg-nbd-1', stan_code=stan_code, data=data, algorithm='LBFGS', jacobian=False, tol_rel_grad=1e-7, iter=5000, generated_var=['a', 'b'])

index = ["r", "α", 'phi', 'kappa', "a", "b"]  # fixed accidental chained assignment `index = index=[...]`
MAP_params = np.array(list(stan_model.opt_model.stan_variables().values()))
var_mat = stan_model.vcov_matrix()  # Variance-Covariance Matrix (dropped stray `var_mar` alias)
se = np.sqrt(np.diag(var_mat))  # Standard Errors of (r, alpha, phi, kappa)

# Delta-method SEs for the derived parameters a = phi*kappa and
# b = (1-phi)*kappa: Var(g(θ)) ≈ ∇g' Σ ∇g. The previous code multiplied the
# standard errors themselves (se_phi * se_kappa and (1 - se_phi) * se_kappa),
# which is not a valid error propagation.
# NOTE(review): assumes stan_variables()/vcov_matrix() are ordered
# (r, alpha, phi_dropout, kappa_dropout) — confirm against StanQuap.
phi_hat, kappa_hat = MAP_params[2], MAP_params[3]
cov_pk = np.asarray(var_mat)[2:4, 2:4]          # 2x2 covariance of (phi, kappa)
grad_a = np.array([kappa_hat, phi_hat])         # (∂a/∂phi, ∂a/∂kappa)
grad_b = np.array([-kappa_hat, 1 - phi_hat])    # (∂b/∂phi, ∂b/∂kappa)
se = np.concatenate((se, [np.sqrt(grad_a @ cov_pk @ grad_a),
                          np.sqrt(grad_b @ cov_pk @ grad_b)]))

plo = MAP_params - 1.96 * se
phi = MAP_params + 1.96 * se
pl.DataFrame(
    {
        "Parameter": index,
        "Coef": MAP_params,
        "SE (Coef)": se,
        "5.5%": plo,
        "94.5%": phi,
    }
)
shape: (6, 5)
Parameter Coef SE (Coef) 5.5% 94.5%
str f64 f64 f64 f64
"r" 0.243722 0.0126 0.219026 0.268418
"α" 4.44371 0.379234 3.700412 5.187008
"phi" 0.252241 0.019638 0.21375 0.290732
"kappa" 2.79763 0.717922 1.390503 4.204757
"a" 0.705676 0.014099 0.678043 0.733309
"b" 2.09195 0.703823 0.712456 3.471444
Code
# Evaluate the analytical negative log-likelihood at the (a, b) implied by
# the phi/kappa MAP fit, for comparison with the direct parameterization.
x = [stan_model.opt_model.optimized_params_pd['r'][0],
     stan_model.opt_model.optimized_params_pd['alpha'][0],
     stan_model.opt_model.optimized_params_pd['a'][0],
     stan_model.opt_model.optimized_params_pd['b'][0]]

print("Log-Likelihood:", bgnbd_ll(x, p1x, t_x, T))
# print(negative_log_likelihood(np.log(np.array(x)), p1x, t_x, T))
# print(negative_log_likelihood(np.log(np.array(x)), p1x, t_x, T))
Log-Likelihood: 9582.570577328574

MCMC Model Fitting

# Full MCMC (NUTS) fit of the reparameterized model; diagnose() reports
# treedepth, divergences, E-BFMI, ESS and R-hat checks.
mcmc = stan_model.stan_model.sample(data=data)
inf_data = az.from_cmdstanpy(mcmc)
print(mcmc.diagnose())
                                                                                                                                                                                                                                                                                                                                
Checking sampler transitions treedepth.
Treedepth satisfactory for all transitions.

Checking sampler transitions for divergences.
No divergent transitions found.

Checking E-BFMI - sampler transitions HMC potential energy.
E-BFMI satisfactory.

Rank-normalized split effective sample size satisfactory for all parameters.

Rank-normalized split R-hat values satisfactory for all parameters.

Processing complete, no problems detected.

Model Summary

Code
mcmc.summary()
Mean MCSE StdDev MAD 5% 50% 95% ESS_bulk ESS_tail R_hat
lp__ -9587.790000 0.035648 1.478200 1.275040 -9590.620000 -9587.470000 -9586.080000 1694.05 2234.96 1.00012
r 0.245175 0.000279 0.012767 0.012615 0.224980 0.244719 0.266439 2124.19 2284.57 1.00105
alpha 4.507340 0.008621 0.380884 0.381873 3.909960 4.497960 5.159060 1947.31 2135.86 1.00147
phi_dropout 0.247511 0.000412 0.020002 0.019909 0.214676 0.247144 0.281065 2371.81 2409.26 1.00151
kappa_dropout 3.211950 0.020720 0.962487 0.812020 2.021490 3.030890 4.966390 2296.31 2154.48 1.00120
a 0.783547 0.003870 0.196652 0.173965 0.527484 0.754011 1.141310 2698.30 2366.67 1.00130
b 2.428400 0.016952 0.774778 0.646940 1.483030 2.275150 3.878010 2230.57 2033.92 1.00173

Model Trace Plot

Code
# Rank-bar trace plot across chains (rank plots are the recommended
# modern check of chain mixing).
axes = az.plot_trace(
    data=inf_data,
    compact=True,
    kind="rank_bars",
    backend_kwargs={"figsize": (12, 9), "layout": "constrained"},
)
plt.gcf().suptitle("BG/NBD Model Trace", fontsize=18, fontweight="bold")
plt.tight_layout();
/var/folders/0s/z9xp988n3j78zfjwg3y616x00000gn/T/ipykernel_88373/1157014386.py:8: UserWarning:

The figure layout has changed to tight

Model Posterior Plot

Code
# Overlay the MCMC posterior of each natural parameter with the three point
# estimates computed earlier: Stan MAP, lifetimes MLE, SciPy MLE.
fig, axes = plt.subplots(2, 2, figsize=(12, 9), sharex=False, sharey=False)

axes = axes.flatten()

for i, var_name in enumerate(["r", "alpha", "a", "b"]):
    ax = axes[i]
    az.plot_posterior(
        inf_data.posterior[var_name].values.flatten(),
        color="C0",
        point_estimate="mean",
        ax=ax,
        label="MCMC",
    )
    # Reference lines; lifetimes_params / scipy_params share this
    # (r, alpha, a, b) index order.
    ax.axvline(x=stan_model.opt_model.stan_variable(var_name), color="C1", linestyle="--", label="Stan MAP")
    ax.axvline(x=lifetimes_params[i], color="C2", linestyle="--", label="Lifetimes")
    ax.axvline(x=scipy_params[i], color="C3", linestyle="--", label="SciPy")
    ax.legend(loc="upper right")
    ax.set_title(var_name)

plt.gcf().suptitle("BG/NBD Model Parameters", fontsize=18, fontweight="bold")
plt.tight_layout();

The r and alpha purchase rate parameters are quite similar for all three models, but the a and b dropout parameters are better approximated with the phi_dropout and kappa_dropout parameters when fitted with MCMC.

Prior and Posterior Predictive Checks

PPCs allow us to check the efficacy of our priors, and the performance of the fitted posteriors.

Prior Predictive Check

Let’s see how the model performs in a prior predictive check, where we sample from the default priors before fitting the model:

# Prior predictive program: draws (r, alpha, phi, kappa) from the priors,
# then forward-simulates each customer's purchase/dropout process over
# their observation window T[n] in generated quantities (no likelihood).
stan_prior_code = '''
data {
    int<lower=0> N;
    vector<lower=0>[N] T;
}

generated quantities {
    real r = weibull_rng(2, 1);
    real alpha = weibull_rng(2, 10);
    real phi_dropout = uniform_rng(0, 1);
    real kappa_dropout = pareto_rng(1, 1);

    real a = phi_dropout * kappa_dropout;
    real b = (1 - phi_dropout) * kappa_dropout;

    array[N] int X_rep;
    vector[N] Tx_rep;

    for (n in 1:N) {
        real lambda_n = gamma_rng(r, alpha);
        real p_n = beta_rng(a, b);
        real current_time = 0;
        int x = 0;
        real tx = 0;
        int active = 1;

        while (active && current_time < T[n]) {
            real wait = exponential_rng(lambda_n);
            if (current_time + wait > T[n]) {
                break;
            } else {
                current_time += wait;
                x += 1;
                tx = current_time;
                if (bernoulli_rng(p_n)) {
                    active = 0;
                }
            }
        }
        X_rep[n] = x;
        Tx_rep[n] = tx;
    }
}
'''
Code
# Prior predictive simulation only needs N and the observation windows T.
data = {
    'N': len(p1x),
    'T': T.flatten().tolist(),
}

stan_model = Stan(stan_file='stan_models/bg-nbd-prior', stan_code=stan_prior_code)
prior_fit = stan_model.sample(data=data)
                                                                                                                                                                                                                                                                                                                                
Code
prior_samples = prior_fit.stan_variable('X_rep').astype(int)

# Compute prior frequency distribution. Customers with more than
# `max_purch` repeat purchases are pooled into the top bin — the previous
# [:max_purch + 1] truncation silently dropped them even though the axis
# advertised a "+" bin (and mislabelled the exactly-9 bin as "10+").
max_purch = 9
pooled = np.minimum(prior_samples, max_purch)
prior_freq_counts = np.zeros((pooled.shape[0], max_purch + 1))
for i in range(pooled.shape[0]):
    prior_freq_counts[i] = np.bincount(pooled[i], minlength=max_purch + 1) / data['N']

mean_frequency = prior_freq_counts.mean(axis=0)
hdi = az.hdi(prior_freq_counts, hdi_prob=0.89)

# Observed distribution, pooled the same way.
observed_pooled = np.minimum(p1x.flatten().astype(int), max_purch)
observed_frequency = np.bincount(observed_pooled, minlength=max_purch + 1) / data['N']

plt.clf()
plt.bar(np.arange(max_purch+1)+0.2, mean_frequency, width=0.4, label='Estimated', color='white', edgecolor='black', linewidth=0.5, yerr=[mean_frequency - hdi[:,0], hdi[:,1] - mean_frequency])
plt.bar(np.arange(max_purch+1)-0.2,  observed_frequency, 0.4, label='Observed', color='black')
plt.xlabel(f"Number of Repeat Purchases (0-{max_purch}+)")
plt.ylabel("% of Customer Population")
plt.title("Prior Predictive Check: Customer Purchase Frequency")
# Last tick is the pooled ">= max_purch" bin.
plt.xticks(ticks=np.arange(max_purch+1), labels=[str(i) for i in range(max_purch)] + [f'{max_purch}+'])
plt.legend();

Posterior Predictive Check

# Posterior predictive program: the (phi, kappa) likelihood model from
# above plus a generated-quantities block that forward-simulates each
# customer's purchase/dropout process from the posterior draws.
stan_post_code = '''
data {
    int<lower=0> N;               
    array[N] int<lower=0> X;      
    vector<lower=0>[N] T;         
    vector<lower=0>[N] Tx;        
}

parameters {
    real<lower=0> r;                         
    real<lower=0> alpha;                     
    real<lower=0, upper=1> phi_dropout;      
    real<lower=1> kappa_dropout;             
}

transformed parameters {
    real a = phi_dropout * kappa_dropout;       
    real b = (1 - phi_dropout) * kappa_dropout; 
}

model {
    // Priors:
    r ~ weibull(2, 1);                
    alpha ~ weibull(2, 10);           
    phi_dropout ~ uniform(0,1);       
    kappa_dropout ~ pareto(1,1);      

    for (n in 1:N) {
        int x = X[n];                 
        real tx = Tx[n];              
        real t = T[n];                

        if (x == 0) {
            target += r * (log(alpha) - log(alpha + t));
        } else {
            real beta_term1 = lbeta(a, b + x) - lbeta(a, b);  
            real gamma_term = lgamma(r + x) - lgamma(r);       
            real term1 = gamma_term + beta_term1 + r * log(alpha) - (r + x) * log(alpha + t);
            
            real beta_term2 = lbeta(a + 1, b + x - 1) - lbeta(a, b);
            real term2 = gamma_term + beta_term2 + r * log(alpha) - (r + x) * log(alpha + tx);
            
            target += log_sum_exp(term1, term2);
        }
    }
}

generated quantities {
    array[N] int X_rep;
    vector[N] Tx_rep;

    for (n in 1:N) {
        real lambda_n = gamma_rng(r, alpha);
        real p_n = beta_rng(a, b);
        real current_time = 0;
        int x = 0;
        real tx = 0;
        int active = 1;

        while (active && current_time < T[n]) {
            real wait = exponential_rng(lambda_n);
            if (current_time + wait > T[n]) {
                break;
            } else {
                current_time += wait;
                x += 1;
                tx = current_time;
                if (bernoulli_rng(p_n)) {
                    active = 0;
                }
            }
        }
        X_rep[n] = x;
        Tx_rep[n] = tx;
    }
}
'''
Code
# Full data block again (likelihood + predictive simulation).
data = {
    'N': len(p1x),
    'X': p1x.flatten().astype(int).tolist(),
    'T': T.flatten().tolist(),
    'Tx': t_x.flatten().tolist()
}

stan_model = Stan(stan_file='stan_models/bg-nbd-post', stan_code=stan_post_code)
post_preds = stan_model.sample(data=data)
                                                                                                                                                                                                                                                                                                                                
Code
posterior = az.from_cmdstanpy(post_preds, posterior_predictive=['X_rep', 'Tx_rep'])
posterior_samples = post_preds.stan_variable('X_rep').astype(int) 

# Compute frequency distribution
max_purch = 9
frequency_counts = np.zeros((posterior_samples.shape[0], max_purch + 1))
for i in range(posterior_samples.shape[0]):
    counts = np.bincount(posterior_samples[i], minlength=max_purch + 1)
    frequency_counts[i] = counts[:max_purch + 1] / data['N']

# Calculate mean and HDI
mean_frequency = frequency_counts.mean(axis=0)
hdi = az.hdi(frequency_counts, hdi_prob=0.89)

observed_counts = np.bincount(p1x.flatten().astype(int), minlength=max_purch + 1)
observed_frequency = observed_counts[:max_purch + 1] / data['N']

plt.clf()
plt.bar(np.arange(max_purch+1)+0.2, mean_frequency, width=0.4, label='Estimated', color='white', edgecolor='black', linewidth=0.5, yerr=[mean_frequency - hdi[:,0], hdi[:,1] - mean_frequency])
plt.bar(np.arange(max_purch+1)-0.2, observed_frequency, width=0.4, label='Observed', color='black')
plt.xlabel("Number of Repeat Purchases (0-10+)")
plt.ylabel("% of Customer Population")
plt.title("Posterior Predictive Check: Customer Purchase Frequency")
plt.xticks(ticks=np.arange(max_purch+1), labels=[f'{i}' if i < max_purch else f'{max_purch+1}+' for i in range(max_purch+1)])
plt.legend();